#!/usr/bin/env python3

import os
import random
import shutil
import subprocess
import sys
import traceback
import urllib.parse
import requests
from typing import List

from checker_utils import get_checker_class, basedir, run_checker, CHECKER_PACKAGES_PATH


def make_badge(text: str, color: str, name='exploits'):
    """Download a static shields.io badge and save it as ``public/ci-<name>.svg``.

    Besides the badge, an empty ``.nobadge`` file is created (or truncated) in
    ``basedir``; its consumer is not visible in this file.

    :param text: right-hand badge message, e.g. ``'ok (3)'`` or ``'1/3 failed'``
    :param color: badge color passed through to shields.io
    :param name: left-hand badge label; also part of the output file name
    :raises RuntimeError: if shields.io does not answer with HTTP 200
    :raises requests.RequestException: on connection problems or timeout
    """

    def _segment(value: str) -> str:
        # shields.io separates the badge fields with "-" and renders "_" as a
        # space, so literal dashes/underscores have to be doubled.
        escaped = value.replace('-', '--').replace('_', '__')
        # safe='' also encodes "/" (quote() keeps it by default), otherwise a
        # text like "1/3 failed" would be split into extra URL path segments.
        return urllib.parse.quote(escaped, safe='')

    color_part = urllib.parse.quote(color, safe='')
    url = f'https://img.shields.io/badge/{_segment(name)}-{_segment(text)}-{color_part}'
    # A timeout keeps a CI job from hanging forever if shields.io is unreachable.
    r = requests.get(url, timeout=30)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if r.status_code != 200:
        raise RuntimeError(f'Badge download failed with HTTP status {r.status_code}: {url}')
    os.makedirs(os.path.join(basedir, 'public'), exist_ok=True)
    with open(os.path.join(basedir, 'public', f'ci-{name}.svg'), 'wb') as f:
        f.write(r.content)
    # Create/truncate the empty marker file.
    with open(os.path.join(basedir, '.nobadge'), 'w'):
        pass


def store_some_flags(target):
    """Store flags for ticks 1..49 in the service at ``target`` using the checker.

    A checker instance is created with a random id and a team object with a
    random team id. For every tick ``store_flags`` is run through
    ``run_checker`` and the flag IDs of that tick are collected.

    :param target: address of the service under test (passed to ``gamelib.Team``)
    :return: tuple ``(checker, flag_ids, retrieve_cmd)`` where ``flag_ids[i]``
        holds one flag ID per tick for the i-th entry of
        ``checker.flag_id_types`` and ``retrieve_cmd`` is a shell command line
        that invokes ``retrieve-flags.py`` with the same target and ids
    :raises RuntimeError: if storing flags does not report ``'SUCCESS'``
    """
    import shlex

    cls_id = random.randint(1, 10)
    team_id = random.randint(1, 1000)
    cls = get_checker_class()
    checker = cls(cls_id)
    print('[OK]  Checker class has been created.')
    # Imported late, after get_checker_class() has run.
    # NOTE(review): presumably gamelib is only importable at this point - confirm.
    import gamelib
    team = gamelib.Team(team_id, os.urandom(6).hex(), target)
    flag_ids = [[] for _ in checker.flag_id_types]
    for tick in range(1, 50):
        checker.initialize_team(team)
        try:
            print(f'[...] Run store_flags(team, {tick})')
            status, msg = run_checker(checker.store_flags, team, tick)
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if status != 'SUCCESS':
                raise RuntimeError(f'Wrong status: {status} ("{msg}")')
            for i, ids in enumerate(flag_ids):
                ids.append(checker.get_flag_id(team, tick, i))
        finally:
            checker.finalize_team(team)

    # build a command to retrieve all flags - if the exploit requires it
    scriptname = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'retrieve-flags.py')
    # shlex.quote keeps the command valid even if a path or the target contains
    # spaces, single quotes or other shell metacharacters.
    retrieve_cmd = (
        f"cd {shlex.quote(os.getcwd())} && "
        f"python3 -u {shlex.quote(scriptname)} {shlex.quote(str(target))} {cls_id} {team_id}"
    )

    return checker, flag_ids, retrieve_cmd


def test_exploit(exploit_file: str, target: str, checker, flag_ids: List[List[str]],
                 timeout: float = 60) -> bool:
    """Run one exploit script and report whether it produced a valid flag.

    The exploit is started as ``python3 <exploit_file> <target> <ids>...`` inside
    the ``exploits`` folder, with one comma-joined argument per flag-id list.
    Its combined stdout/stderr is searched with ``checker.search_flags``; a flag
    counts as valid if the first element of ``checker.check_flag(flag)`` is not
    ``None``.

    :param exploit_file: file name of the exploit, relative to the exploits folder
    :param target: address of the service to attack
    :param checker: checker instance providing ``search_flags`` and ``check_flag``
    :param flag_ids: flag IDs, one inner list per flag-id type
    :param timeout: seconds the exploit may run before it is aborted
    :return: ``True`` if at least one valid flag was found, ``False`` if the
        exploit failed, timed out or returned no valid flag
    """
    print(f'\n\n--- Testing {exploit_file} ---')
    cmd = ['python3', exploit_file, target] + [','.join(ids) for ids in flag_ids]
    try:
        raw_output = subprocess.check_output(
            cmd, cwd=os.path.join(basedir, 'exploits'), stderr=subprocess.STDOUT,
            timeout=timeout, start_new_session=True
        )
    except subprocess.CalledProcessError as e:
        print(f'Process failed with code {e.returncode}')
        print('Output: ')
        print((e.stdout or b'').decode('utf-8', errors='ignore'))
        return False
    except subprocess.TimeoutExpired as e:
        # A hanging exploit is a failed exploit; show what it printed so far
        # instead of surfacing only a traceback. The captured output may be None.
        print(f'Process timed out after {e.timeout} seconds')
        print('Output: ')
        print((e.stdout or b'').decode('utf-8', errors='ignore'))
        return False

    # Decode once and reuse for flag search and diagnostics.
    output = raw_output.decode('utf-8', errors='ignore')
    flags = checker.search_flags(output)
    if len(flags) < 1:
        print('Exploit did not return any flags. Output:\n' + output)
        return False
    valid_flags = [flag for flag in flags if checker.check_flag(flag)[0] is not None]
    if len(valid_flags) < len(flags):
        print(f'[WARNING] Exploit returned {len(flags) - len(valid_flags)} invalid flags (out of {len(flags)})')
    if len(valid_flags) < 1:
        print('Exploit did not return any valid flags. Output:\n' + output)
        return False
    return True


def main(target: str):
    """Test all exploits against ``target`` and publish the result as a badge.

    Steps: find ``exploit*.py`` files in the ``exploits`` folder, store flags in
    the service with the checker, run every exploit, then print the result and
    create a badge via ``make_badge``.

    :param target: address of the service under test
    :raises Exception: if at least one exploit failed; errors from storing
        flags are re-raised after a "checker error" badge has been written
    """
    # Check if any exploits are present (a missing folder counts as "none")
    exploits_dir = os.path.join(basedir, 'exploits')
    if os.path.isdir(exploits_dir):
        # Sorted, because os.listdir() returns entries in arbitrary order.
        exploit_files = sorted(
            f for f in os.listdir(exploits_dir) if f.startswith('exploit') and f.endswith('.py')
        )
    else:
        exploit_files = []
    if not exploit_files:
        print('No exploits found.')
        make_badge('none', 'yellow')
        return
    # Checker script is required to store flags in the service
    if not os.path.exists(os.path.join(basedir, 'checkers', 'config')):
        print('No checkerscript found. Create a file "config" in folder "checkers", content: "your-script-file.py:YourClassName".')
        make_badge('untested', 'yellow')
        return
    # 1. Add flags to the service
    try:
        checker, flag_ids, retrieve_cmd = store_some_flags(target)
    except Exception:
        # `Exception` only: Ctrl-C / SystemExit should not be reported as a
        # checker error.
        print('Cannot test exploits, because no flags could be stored in the service.')
        make_badge('checker error', 'yellow')
        raise
    # Exported via the environment, so exploit subprocesses inherit it.
    os.environ['CHECKER_RETRIEVE_CMD'] = retrieve_cmd

    # 2. Check for each exploit if it can retrieve at least one flag
    good = 0
    for exploit_file in exploit_files:
        try:
            if test_exploit(exploit_file, target, checker, flag_ids):
                good += 1
        except Exception:
            # One broken exploit must not stop the others, but do not swallow
            # KeyboardInterrupt/SystemExit as a bare `except:` would.
            traceback.print_exc()

    # 3. Report result (as output and badge)
    total = len(exploit_files)
    print(f'[RESULT] {good} / {total} exploits passed.')
    if good == total:
        make_badge(f'ok ({good})', 'brightgreen')
        return
    make_badge(f'{total - good}/{total} failed', 'red')
    raise Exception(f'{total - good} of {total} exploits failed.')


if __name__ == '__main__':
    # The first command line argument is the target; fall back to localhost.
    cli_args = sys.argv[1:]
    if cli_args:
        target = cli_args[0]
    else:
        target = '127.0.0.1'
    print(f'Checking exploits against "{target}" ...')
    try:
        main(target)
    finally:
        # Remove CHECKER_PACKAGES_PATH whether main() succeeded or raised;
        # errors during removal are ignored.
        shutil.rmtree(CHECKER_PACKAGES_PATH, ignore_errors=True)
